package com.chatplus.application.aiprocessor.handler;

import com.chatplus.application.aiprocessor.util.WebSocketManager;
import com.chatplus.application.common.logging.SouthernQuietLogger;
import com.chatplus.application.common.logging.SouthernQuietLoggerFactory;
import com.chatplus.application.domain.dto.ws.WsChatMessage;
import com.chatplus.application.domain.request.ws.ChatWebSocketRequest;
import com.chatplus.application.enumeration.MessageTypeEnum;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.jetbrains.annotations.NotNull;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.AbstractWebSocketHandler;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Optional;

/**
 * SD 处理器
 * 基于Spring WebSocket
 */
@Component
public class SdWebSocketHandler extends AbstractWebSocketHandler {
    private static final SouthernQuietLogger LOGGER = SouthernQuietLoggerFactory.getLogger(SdWebSocketHandler.class);

    /** Prefix of the key under which SD sessions are registered in {@link WebSocketManager}. */
    private static final String SD_WS_PREFIX = "SD:WS:";

    /**
     * Session attribute holding this connection's {@link WebSocketManager} key. It is set only once
     * the connection has been counted, so the close/error callbacks can tell whether there is
     * anything to release. Per-connection state lives in the session (not in fields) because this
     * {@code @Component} is a single instance shared by every connection.
     */
    private static final String SESSION_KEY_ATTR = SdWebSocketHandler.class.getName() + ".sessionKey";

    /** Number of currently registered SD connections; guarded by the class monitor. */
    private static int onlineCount;

    private final ObjectMapper objectMapper;

    @Autowired
    public SdWebSocketHandler(ObjectMapper objectMapper) {
        this.objectMapper = objectMapper;
    }

    /**
     * Returns the current number of registered connections.
     */
    public static synchronized int getOnlineCount() {
        return onlineCount;
    }

    /**
     * Increments the connection count by one.
     */
    public static synchronized void addOnlineCount() {
        SdWebSocketHandler.onlineCount++;
    }

    /**
     * Decrements the connection count by one.
     */
    public static synchronized void subOnlineCount() {
        SdWebSocketHandler.onlineCount--;
    }

    /**
     * Registers a newly established connection under {@code SD:WS:<userId>}.
     * Connections whose handshake did not supply a user are closed with
     * {@link CloseStatus#POLICY_VIOLATION} and are neither counted nor registered.
     */
    @Override
    public void afterConnectionEstablished(@NotNull WebSocketSession session) {
        Long userId = userIdOf(session);
        if (userId == null) {
            LOGGER.message("SD建立连接失败: 缺少用户信息").context("连接信息", session.getAttributes()).error();
            try {
                session.close(CloseStatus.POLICY_VIOLATION);
            } catch (IOException e) {
                LOGGER.message("关闭SD连接异常").exception(e).error();
            }
            return;
        }
        String sessionKey = sessionKey(userId);
        addOnlineCount();
        // Mark as counted right away so the close callback always balances the increment.
        session.getAttributes().put(SESSION_KEY_ATTR, sessionKey);
        WebSocketManager.add(sessionKey, session);
        LOGGER.message("SD建立连接").context("连接ID", userId)
                .context("当前连接数", getOnlineCount()).info();
    }

    /**
     * Handles an incoming message. Payloads that are not valid {@link WsChatMessage} JSON are
     * treated as plain chat text. Heartbeats are ignored; other message types are currently parsed
     * but not acted upon.
     */
    @Override
    public void handleMessage(@NotNull WebSocketSession session, @NotNull WebSocketMessage<?> message) {
        if (message.getPayloadLength() == 0) {
            return;
        }
        String payload = message.getPayload().toString();
        WsChatMessage wsChatMessage;
        try {
            wsChatMessage = objectMapper.readValue(payload, WsChatMessage.class);
        } catch (Exception e) {
            wsChatMessage = new WsChatMessage(MessageTypeEnum.CHAT, payload);
        }
        // readValue yields null for a literal JSON "null" payload.
        if (wsChatMessage == null || MessageTypeEnum.HEARTBEAT.equals(wsChatMessage.getType())) {
            return;
        }
        // NOTE(review): non-heartbeat messages are not processed yet.
    }

    /**
     * Handles a transport error by unregistering and closing this connection. The connection
     * count is adjusted later in {@link #afterConnectionClosed}.
     */
    @Override
    public void handleTransportError(WebSocketSession session, @NotNull Throwable exception) throws IOException {
        Object sessionKey = session.getAttributes().get(SESSION_KEY_ATTR);
        if (sessionKey != null) {
            releaseSession(sessionKey.toString(), session);
        } else if (session.isOpen()) {
            session.close(CloseStatus.SERVER_ERROR);
        }
        LOGGER.message("连接发送错误").context("连接信息", session.getAttributes()).context("exception", exception.getMessage()).error();
    }

    /**
     * Cleans up after a connection closes: decrements the count and unregisters the session, but
     * only for connections that were actually registered, and at most once per connection.
     */
    @Override
    public void afterConnectionClosed(@NotNull WebSocketSession session, @NotNull CloseStatus closeStatus) throws IOException {
        // remove() makes the release idempotent and skips connections that were never counted.
        Object sessionKey = session.getAttributes().remove(SESSION_KEY_ATTR);
        if (sessionKey != null) {
            subOnlineCount();
            releaseSession(sessionKey.toString(), session);
        }
        LOGGER.message("断开连接").context("连接用户", userIdOf(session)).context("当前连接数", getOnlineCount()).info();
    }

    /**
     * Pushes an SD drawing status update to the user's connection.
     *
     * @param userId   user whose connection should receive the update
     * @param progress task progress; {@code 100} sends {@code "finish"}, anything else
     *                 (including {@code null}) sends {@code "running"}
     */
    public void sendTaskUpdatedMessage(Long userId, Integer progress) {
        try {
            WebSocketSession session = WebSocketManager.get(sessionKey(userId));
            if (session == null || !session.isOpen()) {
                LOGGER.message("session会话已经关闭").context("连接用户", userId).error();
                return;
            }
            String status = progress != null && progress == 100 ? "finish" : "running";
            session.sendMessage(new BinaryMessage(status.getBytes(StandardCharsets.UTF_8)));
        } catch (IOException e) {
            LOGGER.message("发送SD画图状态变更消息异常").exception(e).error();
        }
    }

    /** Builds the {@link WebSocketManager} key for a user's SD connection. */
    private static String sessionKey(Long userId) {
        return SD_WS_PREFIX + userId;
    }

    /** Reads the user id from the handshake request stored in the session, or {@code null} if absent. */
    private static Long userIdOf(WebSocketSession session) {
        Object request = session.getAttributes().get(ChatWebSocketRequest.SD_URL_PATH);
        return request instanceof ChatWebSocketRequest ? ((ChatWebSocketRequest) request).getUserId() : null;
    }

    /**
     * Unregisters {@code session} if it is still the one mapped under {@code sessionKey}, and makes
     * sure it is closed. A newer connection of the same user that now owns the key is left alone.
     */
    private static void releaseSession(String sessionKey, WebSocketSession session) throws IOException {
        WebSocketSession registered = WebSocketManager.get(sessionKey);
        // Compare by id rather than identity in case the manager stores a decorated session.
        if (registered != null && session.getId().equals(registered.getId())) {
            WebSocketManager.removeAndClose(sessionKey);
        } else if (session.isOpen()) {
            session.close();
        }
    }
}
